{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install --user graphviz" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Lab 12 - Decision Trees\n", "\n", "For this lab, we will use survey data collected by the city of [Somerville, MA](https://en.wikipedia.org/wiki/Somerville,_Massachusetts). The survey asks residents about their happiness and asks them to rate several city services. \n", "\n", "The data is available from the UC Irvine Machine Learning Repository: [https://archive.ics.uci.edu/ml/datasets/Somerville+Happiness+Survey](https://archive.ics.uci.edu/ml/datasets/Somerville+Happiness+Survey)\n", "\n", "The link to download the data is [https://archive.ics.uci.edu/ml/machine-learning-databases/00479/SomervilleHappinessSurvey2015.csv](https://archive.ics.uci.edu/ml/machine-learning-databases/00479/SomervilleHappinessSurvey2015.csv)\n", "\n", "The data columns are:\n", "\n", "- D = decision attribute, with values 0 (unhappy) and 1 (happy) \n", "- X1 = the availability of information about the city services \n", "- X2 = the cost of housing \n", "- X3 = the overall quality of public schools \n", "- X4 = your trust in the local police \n", "- X5 = the maintenance of streets and sidewalks \n", "- X6 = the availability of social community events \n", "\n", "Attributes X1 to X6 are integer ratings from 1 to 5." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from sklearn import tree\n", "import graphviz\n", "from graphviz import Source\n", "from sklearn.tree import export_graphviz\n", "import sklearn.metrics as met\n", "from sklearn.metrics import confusion_matrix\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read the data into a DataFrame. In the second cell below, we give the columns more descriptive names." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " D X1 X2 X3 X4 X5 X6\n", "0 0 3 3 3 4 2 4\n", "1 0 3 2 3 5 4 3\n", "2 1 5 3 3 3 3 5\n", "3 0 5 4 3 3 3 5\n", "4 0 5 4 3 3 3 5\n", "5 1 5 5 3 5 5 5\n", "6 0 3 1 2 2 1 3\n", "7 1 5 4 4 4 4 5\n", "8 0 4 1 4 4 4 4\n", "9 0 4 4 4 2 5 5\n", "10 0 3 2 3 3 2 3\n", "11 0 4 4 3 4 4 4\n", "12 1 5 2 4 5 5 5\n", "13 0 4 2 4 5 4 3\n", "14 0 4 1 3 3 4 3\n", "15 1 3 2 4 3 4 4\n", "16 0 5 3 4 5 4 5\n", "17 1 5 1 4 3 4 5\n", "18 0 5 1 2 4 4 5\n", "19 0 4 2 4 4 4 4\n", "20 1 4 2 3 3 4 4\n", "21 1 4 2 3 3 4 4\n", "22 0 4 3 5 5 5 4\n", "23 0 4 3 5 5 5 4\n", "24 1 5 1 2 5 2 4\n", "25 1 4 3 3 3 3 4\n", "26 0 3 4 3 3 2 3\n", "27 1 3 3 3 5 5 5\n", "28 1 3 3 1 3 3 4\n", "29 1 3 3 1 3 3 4\n", ".. .. .. .. .. .. .. ..\n", "113 1 5 3 4 3 4 3\n", "114 0 5 2 3 3 3 5\n", "115 1 4 3 4 4 3 4\n", "116 0 3 2 2 3 3 3\n", "117 0 4 1 3 5 3 5\n", "118 1 5 1 4 3 5 5\n", "119 0 4 1 3 2 4 4\n", "120 1 5 1 5 3 5 5\n", "121 0 4 2 4 4 4 4\n", "122 0 5 2 4 4 5 5\n", "123 1 5 3 4 4 4 4\n", "124 1 5 2 4 4 2 3\n", "125 0 5 3 3 4 4 5\n", "126 0 5 3 3 4 4 4\n", "127 0 3 2 3 3 5 4\n", "128 0 4 1 3 3 3 4\n", "129 0 5 1 4 4 4 5\n", "130 0 5 2 2 4 4 5\n", "131 0 5 3 5 4 5 5\n", "132 1 3 4 4 5 1 3\n", "133 1 5 1 5 5 5 5\n", "134 1 4 3 3 4 4 4\n", "135 1 5 5 1 1 5 1\n", "136 0 4 4 4 4 1 3\n", "137 1 5 2 3 4 4 3\n", "138 0 5 3 3 1 3 5\n", "139 1 5 2 3 4 2 5\n", "140 1 5 3 3 4 4 5\n", "141 0 4 3 3 4 4 5\n", "142 0 5 3 2 5 5 5\n", "\n", "[143 rows x 7 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.read_csv(\"SomervilleHappinessSurvey2015.csv\", encoding = \"utf-16le\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " happy city_info housing_cost school_quality trust_police \\\n", "0 0 3 3 3 4 \n", "1 0 3 2 3 5 \n", "2 1 5 3 3 3 \n", "3 0 5 4 3 3 \n", "4 0 5 4 3 3 \n", "\n", " streets_sidewalks community_events \n", "0 2 4 \n", "1 4 3 \n", "2 3 5 \n", "3 3 5 \n", "4 3 5 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "new_column_names = [\"happy\",\"city_info\",\"housing_cost\", \"school_quality\", \\\n", " \"trust_police\", \"streets_sidewalks\", \"community_events\"]\n", "city = pd.read_csv(\"SomervilleHappinessSurvey2015.csv\", \\\n", " encoding = \"utf-16le\",names = new_column_names, \\\n", " header = 0)\n", "city.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classwork\n", "\n", "The code below lets you build your own depth-2 decision tree by hand: one split at the top, then one split on each side. Which three split conditions give the highest accuracy?" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# top level of decision tree\n", "filter_level_1 = city[\"school_quality\"] < 4\n", "level_2_left = city[filter_level_1]\n", "level_2_right = city[~filter_level_1]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# second level of decision tree on left\n", "filter_level_2_left = level_2_left[\"housing_cost\"] < 4\n", "level_3_left_left = level_2_left[filter_level_2_left]\n", "level_3_left_right = level_2_left[~filter_level_2_left]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# second level of decision tree on right\n", "filter_level_2_right = level_2_right[\"community_events\"] < 4\n", "level_3_right_left = level_2_right[filter_level_2_right]\n", "level_3_right_right = level_2_right[~filter_level_2_right]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sensitivity: 0.4675324675324675\n", "Specificity: 0.7424242424242424\n", "Precision: 0.6792452830188679\n", "Accuracy: 0.5944055944055944\n" ] } ], "source": [ "# make predictions: each leaf predicts its majority class (ties go to 1 = happy)\n", "# labels=[0, 1] keeps every leaf's confusion matrix 2x2, even if a leaf holds only one class\n", "\n", "proportion_1 = level_3_left_left[\"happy\"].sum()/level_3_left_left.shape[0]\n", "if (proportion_1 >= 0.5):\n", "    confusion_matrix_left_left = confusion_matrix(level_3_left_left[\"happy\"],np.ones(level_3_left_left.shape[0]),labels=[0, 1])\n", "else:\n", "    confusion_matrix_left_left = confusion_matrix(level_3_left_left[\"happy\"],np.zeros(level_3_left_left.shape[0]),labels=[0, 1])\n", "\n", "proportion_1 = level_3_left_right[\"happy\"].sum()/level_3_left_right.shape[0]\n", "if (proportion_1 >= 0.5):\n", "    confusion_matrix_left_right = confusion_matrix(level_3_left_right[\"happy\"],np.ones(level_3_left_right.shape[0]),labels=[0, 1])\n", "else:\n", "    confusion_matrix_left_right = confusion_matrix(level_3_left_right[\"happy\"],np.zeros(level_3_left_right.shape[0]),labels=[0, 1])\n", "\n", "proportion_1 = level_3_right_left[\"happy\"].sum()/level_3_right_left.shape[0]\n", "if (proportion_1 >= 0.5):\n", "    confusion_matrix_right_left = confusion_matrix(level_3_right_left[\"happy\"],np.ones(level_3_right_left.shape[0]),labels=[0, 1])\n", "else:\n", "    confusion_matrix_right_left = confusion_matrix(level_3_right_left[\"happy\"],np.zeros(level_3_right_left.shape[0]),labels=[0, 1])\n", "\n", "proportion_1 = level_3_right_right[\"happy\"].sum()/level_3_right_right.shape[0]\n", "if (proportion_1 >= 0.5):\n", "    confusion_matrix_right_right = confusion_matrix(level_3_right_right[\"happy\"],np.ones(level_3_right_right.shape[0]),labels=[0, 1])\n", "else:\n", "    confusion_matrix_right_right = confusion_matrix(level_3_right_right[\"happy\"],np.zeros(level_3_right_right.shape[0]),labels=[0, 1])\n", "\n", "# add the four leaf matrices to get the confusion matrix of the whole tree\n", "cm = confusion_matrix_left_left + confusion_matrix_left_right + confusion_matrix_right_left + \\\n", "     confusion_matrix_right_right\n", "\n", "tn, fp, fn, tp = cm.ravel()\n", "\n", "sensitivity = tp/(tp + fn)\n", "specificity = tn/(tn + fp)\n", "precision = tp/(tp + fp)\n", "accuracy = (tp + tn)/(tp + tn + fp + fn)\n", "\n", 
"print(\"Sensitivity:\",sensitivity)\n", "print(\"Specificity:\",specificity)\n", "print(\"Precision:\", precision)\n", "print(\"Accuracy:\",accuracy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fitting a decision tree with scikit-learn\n", "\n", "We can select just the independent variables (the x's, in column positions 1 through 6) using the following:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " city_info housing_cost school_quality trust_police streets_sidewalks \\\n", "0 3 3 3 4 2 \n", "1 3 2 3 5 4 \n", "2 5 3 3 3 3 \n", "3 5 4 3 3 3 \n", "4 5 4 3 3 3 \n", "5 5 5 3 5 5 \n", "6 3 1 2 2 1 \n", "7 5 4 4 4 4 \n", "8 4 1 4 4 4 \n", "9 4 4 4 2 5 \n", "10 3 2 3 3 2 \n", "11 4 4 3 4 4 \n", "12 5 2 4 5 5 \n", "13 4 2 4 5 4 \n", "14 4 1 3 3 4 \n", "15 3 2 4 3 4 \n", "16 5 3 4 5 4 \n", "17 5 1 4 3 4 \n", "18 5 1 2 4 4 \n", "19 4 2 4 4 4 \n", "20 4 2 3 3 4 \n", "21 4 2 3 3 4 \n", "22 4 3 5 5 5 \n", "23 4 3 5 5 5 \n", "24 5 1 2 5 2 \n", "25 4 3 3 3 3 \n", "26 3 4 3 3 2 \n", "27 3 3 3 5 5 \n", "28 3 3 1 3 3 \n", "29 3 3 1 3 3 \n", ".. ... ... ... ... ... \n", "113 5 3 4 3 4 \n", "114 5 2 3 3 3 \n", "115 4 3 4 4 3 \n", "116 3 2 2 3 3 \n", "117 4 1 3 5 3 \n", "118 5 1 4 3 5 \n", "119 4 1 3 2 4 \n", "120 5 1 5 3 5 \n", "121 4 2 4 4 4 \n", "122 5 2 4 4 5 \n", "123 5 3 4 4 4 \n", "124 5 2 4 4 2 \n", "125 5 3 3 4 4 \n", "126 5 3 3 4 4 \n", "127 3 2 3 3 5 \n", "128 4 1 3 3 3 \n", "129 5 1 4 4 4 \n", "130 5 2 2 4 4 \n", "131 5 3 5 4 5 \n", "132 3 4 4 5 1 \n", "133 5 1 5 5 5 \n", "134 4 3 3 4 4 \n", "135 5 5 1 1 5 \n", "136 4 4 4 4 1 \n", "137 5 2 3 4 4 \n", "138 5 3 3 1 3 \n", "139 5 2 3 4 2 \n", "140 5 3 3 4 4 \n", "141 4 3 3 4 4 \n", "142 5 3 2 5 5 \n", "\n", " community_events \n", "0 4 \n", "1 3 \n", "2 5 \n", "3 5 \n", "4 5 \n", "5 5 \n", "6 3 \n", "7 5 \n", "8 4 \n", "9 5 \n", "10 3 \n", "11 4 \n", "12 5 \n", "13 3 \n", "14 3 \n", "15 4 \n", "16 5 \n", "17 5 \n", "18 5 \n", "19 4 \n", "20 4 \n", "21 4 \n", "22 4 \n", "23 4 \n", "24 4 \n", "25 4 \n", "26 3 \n", "27 5 \n", "28 4 \n", "29 4 \n", ".. ... 
\n", "113 3 \n", "114 5 \n", "115 4 \n", "116 3 \n", "117 5 \n", "118 5 \n", "119 4 \n", "120 5 \n", "121 4 \n", "122 5 \n", "123 4 \n", "124 3 \n", "125 5 \n", "126 4 \n", "127 4 \n", "128 4 \n", "129 5 \n", "130 5 \n", "131 5 \n", "132 3 \n", "133 5 \n", "134 4 \n", "135 1 \n", "136 3 \n", "137 3 \n", "138 5 \n", "139 5 \n", "140 5 \n", "141 5 \n", "142 5 \n", "\n", "[143 rows x 6 columns]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "city.iloc[:,1:7]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we create a decision tree classifier object with a maximum depth of 2 (the same depth as the tree you built by hand) and fit it to our data:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "clf = tree.DecisionTreeClassifier(max_depth = 2)\n", "clf = clf.fit(city.iloc[:,1:7], city[\"happy\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you are running Jupyter on your own computer with scikit-learn 0.21 or later, you can display the decision tree with `tree.plot_tree`. Older versions of scikit-learn do not have this function and raise the `AttributeError` shown below:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "ename": "AttributeError", "evalue": "'module' object has no attribute 'plot_tree'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtree\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'plot_tree'" ] } ], "source": [ "tree.plot_tree(clf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you are using the JupyterHub server, run the following code instead. It raises an error, but it still produces the file we need:" ] }, { "cell_type": "code", 
"execution_count": 15, "metadata": { "scrolled": true }, "outputs": [ { "ename": "PermissionError", "evalue": "[Errno 13] Permission denied", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mPermissionError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdot_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexport_graphviz\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_file\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mgraph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgraphviz\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSource\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdot_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mgraph\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"happiness.dot\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/.local/lib/python3.4/site-packages/graphviz/files.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, filename, directory, view, cleanup, format, renderer, formatter)\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0mformat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_format\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 188\u001b[0;31m \u001b[0mrendered\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mfilepath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 189\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcleanup\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/.local/lib/python3.4/site-packages/graphviz/backend.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(engine, format, filepath, renderer, formatter, quiet)\u001b[0m\n\u001b[1;32m 181\u001b[0m \"\"\"\n\u001b[1;32m 182\u001b[0m \u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrendered\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcommand\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mengine\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilepath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 183\u001b[0;31m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcapture_output\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcheck\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mquiet\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mquiet\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 184\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrendered\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/.local/lib/python3.4/site-packages/graphviz/backend.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(cmd, input, capture_output, check, quiet, **kwargs)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 
147\u001b[0;31m \u001b[0mproc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msubprocess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcmd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstartupinfo\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mget_startupinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mOSError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merrno\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0merrno\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mENOENT\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.4/subprocess.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, args, bufsize, executable, stdin, stdout, stderr, preexec_fn, close_fds, shell, cwd, env, universal_newlines, startupinfo, creationflags, restore_signals, start_new_session, pass_fds)\u001b[0m\n\u001b[1;32m 854\u001b[0m \u001b[0mc2pread\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc2pwrite\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 855\u001b[0m \u001b[0merrread\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merrwrite\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 856\u001b[0;31m restore_signals, start_new_session)\n\u001b[0m\u001b[1;32m 857\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 858\u001b[0m \u001b[0;31m# Cleanup if the child failed starting.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.4/subprocess.py\u001b[0m in \u001b[0;36m_execute_child\u001b[0;34m(self, args, executable, preexec_fn, close_fds, pass_fds, 
cwd, env, startupinfo, creationflags, shell, p2cread, p2cwrite, c2pread, c2pwrite, errread, errwrite, restore_signals, start_new_session)\u001b[0m\n\u001b[1;32m 1462\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1463\u001b[0m \u001b[0merr_msg\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m': '\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mrepr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morig_executable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1464\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mchild_exception_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merrno_num\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merr_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1465\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mchild_exception_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merr_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1466\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mPermissionError\u001b[0m: [Errno 13] Permission denied" ] } ], "source": [ "dot_data = tree.export_graphviz(clf, out_file=None) \n", "graph = graphviz.Source(dot_data) \n", "graph.render(\"happiness.dot\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Despite the error, there should now be a file called happiness.dot in your directory. `graph.render` first saves the DOT source to that file and only then calls the Graphviz `dot` program, and it is that second step that fails with the PermissionError. To view the fitted decision tree, open happiness.dot in Jupyter and copy its text. Paste this text into the text box at [http://www.webgraphviz.com](http://www.webgraphviz.com) and click the \"Generate graph!\" button at the bottom.\n", "\n", "Because we did not pass any column names to `export_graphviz`, the features appear as `X[0], X[1], ..., X[5]`, in the order of the columns in `city.iloc[:,1:7]`. Run the following code to replace them with the column names. It writes the result to happiness_fixed.dot." 
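, "\n", "\n", "As an alternative to the renaming cell that comes next, `export_graphviz` accepts a `feature_names` argument (and a `class_names` argument for the predicted classes), so the DOT text can use the column names from the start. The sketch below assumes `clf` and `city` from the cells above. The file name `happiness_named.dot` is a hypothetical choice. The sketch writes the DOT string with ordinary Python file I/O instead of calling `graph.render`. As a result, it never runs the Graphviz `dot` program, which is the step that raised the PermissionError above.\n", "\n", "```python\n", "from sklearn import tree\n", "\n", "# Predictor columns in the same order used to fit clf (city.iloc[:,1:7])\n", "feature_names = list(city.columns[1:7])\n", "\n", "# class_names follow the sorted class labels: 0 -> unhappy, 1 -> happy\n", "# filled=True colors each node by its majority class\n", "dot_data = tree.export_graphviz(clf, out_file=None,\n", "                                feature_names=feature_names,\n", "                                class_names=[\"unhappy\", \"happy\"],\n", "                                filled=True)\n", "\n", "# Save the DOT text directly; no Graphviz executable is invoked\n", "with open(\"happiness_named.dot\", \"w\") as fout:\n", "    fout.write(dot_data)\n", "```\n", "\n", "You can then paste the contents of happiness_named.dot into webgraphviz in the same way as described above."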
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "with open(\"happiness.dot\", \"r\") as fin:\n", "    with open(\"happiness_fixed.dot\", \"w\") as fout:\n", "        for line in fin:\n", "            # replace each generic feature label X[i] with its column name\n", "            line = line.replace(\"X[0]\",\"city_info\")\n", "            line = line.replace(\"X[1]\",\"housing_cost\")\n", "            line = line.replace(\"X[2]\",\"school_quality\")\n", "            line = line.replace(\"X[3]\",\"trust_police\")\n", "            line = line.replace(\"X[4]\",\"streets_sidewalks\")\n", "            line = line.replace(\"X[5]\",\"community_events\")\n", "            fout.write(line)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "66" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(city[\"happy\"] == 0).sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Copy the contents of happiness_fixed.dot into the text box at [http://www.webgraphviz.com](http://www.webgraphviz.com) to display the decision tree with the column names. How does it compare to the decision tree you made?\n", "\n", "To make predictions, we can use the following code:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,\n", " 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1,\n", " 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0,\n", " 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0,\n", " 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1,\n", " 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions = clf.predict(city.iloc[:,1:7])\n", "predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we compute the confusion matrix. Its rows are the true classes (0, then 1) and its columns are the predicted classes:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[54, 12],\n", " [34, 43]])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "met.confusion_matrix(city[\"happy\"], predictions)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then unpack the true negatives, false positives, false negatives, and true positives, and compute the same four metrics as for the hand-built tree:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sensitivity: 0.5584415584415584\n", "Specificity: 0.8181818181818182\n", "Precision: 0.7818181818181819\n", "Accuracy: 0.6783216783216783\n" ] } ], "source": [ "tn, fp, fn, tp = met.confusion_matrix(city[\"happy\"], predictions).ravel()\n", "\n", "sensitivity = tp/(tp + fn)\n", "specificity = tn/(tn + fp)\n", "precision = tp/(tp + fp)\n", "accuracy = (tp + tn)/(tp + tn + fp + fn)\n", "\n", "print(\"Sensitivity:\",sensitivity)\n", "print(\"Specificity:\",specificity)\n", 
"print(\"Precision:\", precision)\n", "print(\"Accuracy:\",accuracy)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.4.8" } }, "nbformat": 4, "nbformat_minor": 2 }